from tianshou.data import Batch
from Network.network_utils import assign_distribution, pytorch_model
from Network.Dists.mask_utils import apply_probabilistic_mask
import torch

# Integer codes for the supported masking forms. apply_mask looks up the
# configured form here and treats code 2 ("mixed") specially.
MASKING_FORMS = dict(
    weighting=0,
    relaxed=1,
    mixed=2,
    hard=3,
)

# Masking form -> name of the distribution passed to assign_distribution when
# init_dists builds the relaxed interaction distribution.
MIXING_DISTRIBUTIONS = dict(
    weighting="Identity",
    relaxed="RelaxedBernoulli",
    soft="RelaxedBernoulli",
    mixed="Identity",
    hard="Identity",
    flat="Identity",
)


def init_dists(args):
    """Build the distributions used by the forward and interaction models.

    Args:
        args: configuration namespace; ``args.inter.masking.masking_form`` must be
            a key of ``MIXING_DISTRIBUTIONS`` (otherwise a KeyError is raised).

    Returns:
        Batch with three entries:
        ``forward`` is ``assign_distribution("Gaussian")``.
        ``inter`` is ``assign_distribution("Bernoulli")``.
        ``relaxed_inter`` is the distribution named by
        ``MIXING_DISTRIBUTIONS[masking_form]``.
    """
    # NOTE: a large unreachable block that used to follow the return statement
    # (a leftover copy of a model constructor referencing `self`/`environment`)
    # was removed; it could never execute.
    forward_dist = assign_distribution("Gaussian")
    inter_dist = assign_distribution("Bernoulli")
    relaxed_inter_dist = assign_distribution(MIXING_DISTRIBUTIONS[args.inter.masking.masking_form])
    return Batch(forward=forward_dist, inter=inter_dist, relaxed_inter=relaxed_inter_dist)


def apply_mask(masking_args, dists, inter_mask, soft=True, flat=False, mixed=False, test=False, iscuda=True):
    """Cap the interaction-mask probabilities, then hand them to ``apply_probabilistic_mask``.

    Context: if the interaction model is in cluster mode, the cluster interaction
    mask is extracted first (inter_mask is then the selection over cluster modes).
    This does not apply in attention mode.

    Args:
        masking_args: namespace providing these fields:
            ``mixed_interaction``: a key of ``MASKING_FORMS``.
            ``cap_probability``: indexable, entries 0 and 1 are used.
            ``dist_temperature``: forwarded as-is.
        dists: object with ``inter`` and ``relaxed_inter`` attributes, such as the
            Batch returned by ``init_dists``.
        inter_mask: mask values. Non-tensor input is wrapped with
            ``pytorch_model.wrap``, and ``revert_mask=True`` is forwarded
            (presumably so the result is converted back — confirm in
            apply_probabilistic_mask).
        soft: pass ``dists.relaxed_inter``. ``dists.inter`` is passed only when
            ``soft`` is falsy or mixing is active.
        flat: when truthy, ``test`` is forwarded; otherwise ``None`` is.
        mixed: request mixing. It is honored only when the configured
            ``mixed_interaction`` form is "mixed".
        test: forwarded to ``apply_probabilistic_mask`` when ``flat``.
        iscuda: ``cuda`` flag used when wrapping a non-tensor mask. It was
            previously ignored (always True); the default keeps that behaviour.

    Returns:
        The result of ``apply_probabilistic_mask``.
    """
    revert_mask = not isinstance(inter_mask, torch.Tensor)
    if revert_mask:
        inter_mask = pytorch_model.wrap(inter_mask, cuda=iscuda)

    # The lookup runs first, so an unknown form still raises KeyError as before.
    mixed = MASKING_FORMS[masking_args.mixed_interaction] == MASKING_FORMS["mixed"] and mixed

    low_cap = masking_args.cap_probability[0]
    high_cap = masking_args.cap_probability[1]
    # Shift down by cap[0] and floor at 0, then lift by cap[1].
    # Subtraction creates a new tensor, so the caller's tensor is never mutated.
    inter_mask = inter_mask - low_cap
    inter_mask[inter_mask < 0] = 0
    inter_mask = inter_mask + high_cap
    # NOTE(review): the upper bound uses cap_probability[0], not [1] — confirm intended.
    upper_bound = 1 - low_cap
    inter_mask[inter_mask > upper_bound] = upper_bound

    # Equivalent to the original `(not soft) or (soft and mixed)`.
    use_hard_dist = (not soft) or mixed
    return apply_probabilistic_mask(inter_mask,
                                    inter_dist=dists.inter if use_hard_dist else None,
                                    relaxed_inter_dist=dists.relaxed_inter if soft else None,
                                    mixed=mixed, test=test if flat else None,
                                    dist_temperature=masking_args.dist_temperature,
                                    revert_mask=revert_mask)

def trace_log_probs(num_objects, log_probs, batch, idx=-1):
    """Average ``log_probs`` over the rows whose interaction trace is active.

    Works with either torch tensors or numpy arrays. The previous version used
    torch-only ``sum(dim=...)`` in one branch and numpy-only ``astype`` in the
    other, so each branch crashed on one of the two kinds of input.

    Args:
        num_objects: threshold for ``idx == -1``. A row is selected when the sum
            of its trace over the last two axes exceeds this value.
        log_probs: per-row log-probabilities, indexed with a boolean mask.
        batch: object with a ``trace`` array/tensor whose first axis indexes rows
            (assumed (rows, objects, ...) — TODO confirm layout).
        idx: -1 to use the whole trace. Otherwise, the index along trace's second
            axis; that slice's last-axis sum must exceed 1 for the row to count.

    Returns:
        Mean of the selected ``log_probs``. When ``idx`` is given and no row is
        selected, returns the sentinel ``-10.0``. With ``idx == -1`` and no
        selected row, the mean of an empty selection is returned (NaN).
    """
    # TODO: assumes that in all, num_objects is the number of passive interaction states
    # in one object, 1 is the number of passive interaction states
    if idx == -1:
        # A positional axis works for both torch.Tensor.sum and numpy.ndarray.sum.
        active = batch.trace.sum(-1).sum(-1) > num_objects
        return log_probs[active].mean()
    active = batch.trace[:, idx].sum(-1) > 1
    # `.any()` exists on both torch and numpy boolean arrays.
    if not active.any():
        return -10.0
    return log_probs[active].mean()